{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2022, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}

///
/// Experiment to show timing for the matrix multiplication example in the SymForce paper.
///
/// Run with:
///
///     build/bin/benchmarks/matrix_multiplication_benchmark_{{ matrix_name }}
///
/// See run_matmul_benchmarks.py for more information
///

#include <chrono>
#include <thread>

#include <Eigen/Core>
#include <Eigen/SparseCore>
#include <spdlog/spdlog.h>

#include <symforce/opt/tic_toc.h>
#include <symforce/opt/util.h>

#include <catch2/catch_template_test_macros.hpp>
#include <catch2/catch_test_macros.hpp>

using namespace sym;

#include "./compute_a_{{ matrix_name }}.h"
#include "./compute_a_dense_{{ matrix_name }}.h"
#include "./compute_a_dense_dynamic_{{ matrix_name }}.h"
#include "./compute_at_b_{{ matrix_name }}.h"
#include "./compute_b_{{ matrix_name }}.h"
#include "./compute_b_dense_{{ matrix_name }}.h"
#include "./compute_b_dense_dynamic_{{ matrix_name }}.h"

// Dense, fixed-size baseline: evaluates the generated dense A and B for the given symbol values
// into fixed-size Eigen matrices and returns the product A^T * B.
//
// Marked noinline so the call is not inlined into the benchmark loop.
template <typename Scalar>
__attribute__((noinline)) Eigen::Matrix<Scalar, {{ M }}, {{ M }}> ComputeDenseFixed{{ matrix_name_camel }}(
    {% for i in range(n_symbols) %}const Scalar x{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
  ) {
  // Both operands share the same {{ N }} x {{ M }} fixed-size type.
  using Operand = Eigen::Matrix<Scalar, {{ N }}, {{ M }}>;

  const Operand lhs = ComputeADense{{ matrix_name_camel }}<Scalar>(
      {% for i in range(n_symbols) %}x{{ i }}{% if not loop.last %},{% endif %}{% endfor %}
  );
  const Operand rhs = ComputeBDense{{ matrix_name_camel }}<Scalar>(
      {% for i in range(n_symbols) %}x{{ i }}{% if not loop.last %},{% endif %}{% endfor %}
  );

  // The product expression is evaluated into the fixed-size return type.
  return lhs.transpose() * rhs;
}

// Dense, dynamic-size baseline: evaluates the generated dynamic-size A and B for the given symbol
// values and returns the product A^T * B as a dynamically sized Eigen matrix.
//
// Marked noinline so the call is not inlined into the benchmark loop.
template <typename Scalar>
__attribute__((noinline)) Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>
  ComputeDenseDynamic{{ matrix_name_camel }}(
    {% for i in range(n_symbols) %}const Scalar x{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
  ) {
  // Operand types are whatever the generated functions return.
  const auto lhs = ComputeADenseDynamic{{ matrix_name_camel }}<Scalar>(
      {% for i in range(n_symbols) %}x{{ i }}{% if not loop.last %},{% endif %}{% endfor %}
  );
  const auto rhs = ComputeBDenseDynamic{{ matrix_name_camel }}<Scalar>(
      {% for i in range(n_symbols) %}x{{ i }}{% if not loop.last %},{% endif %}{% endfor %}
  );

  // The product expression is evaluated into the dynamic-size return type.
  return lhs.transpose() * rhs;
}

// Sparse baseline: evaluates the generated A and B for the given symbol values into
// Eigen::SparseMatrix objects and returns the sparse product A^T * B.
//
// Marked noinline so the call is not inlined into the benchmark loop.
template <typename Scalar>
__attribute__((noinline)) Eigen::SparseMatrix<Scalar> ComputeSparse{{ matrix_name_camel }}(
    {% for i in range(n_symbols) %}const Scalar x{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
  ) {
  using Sparse = Eigen::SparseMatrix<Scalar>;

  const Sparse lhs = ComputeA{{ matrix_name_camel }}<Scalar>(
      {% for i in range(n_symbols) %}x{{ i }}{% if not loop.last %},{% endif %}{% endfor %}
  );
  const Sparse rhs = ComputeB{{ matrix_name_camel }}<Scalar>(
      {% for i in range(n_symbols) %}x{{ i }}{% if not loop.last %},{% endif %}{% endfor %}
  );

  // The product expression is evaluated into the sparse return type.
  return lhs.transpose() * rhs;
}

// ----------------------------------------------------------------------------
// Test Cases
// ----------------------------------------------------------------------------

// Benchmark: times the Eigen sparse evaluation of A^T * B (ComputeSparse{{ matrix_name_camel }}),
// once for each scalar type (double, float).
//
// x0 and x1 are swept from 0.1 in steps of 0.1 while below n_runs_multiplier; every other symbol
// is held at a fixed constant.  The first stored value of each result is accumulated into `sum`,
// and `sum` is published through a volatile after the timed scope so the optimizer cannot treat
// the timed computation as dead code.
TEMPLATE_TEST_CASE("sparse_{{ matrix_name }}", "", double, float) {
  using Scalar = TestType;

  fmt::print("n_runs_multiplier: {};\n", {{ n_runs_multiplier }});

  // Fixed values for the symbols that are not swept by the loops below.
  {% for i in range(2, n_symbols) %}
  const Scalar x{{ i }} = {{ i - 1 }}.0;
  {% endfor %}

  // Pause briefly before the timed section starts.
  std::chrono::milliseconds timespan(100);
  std::this_thread::sleep_for(timespan);

  Scalar sum = 0.0;
  {
    SYM_TIME_SCOPE("sparse_{{ matrix_name }}_{}", typeid(Scalar).name());

    for (Scalar x0 = 0.1; x0 < {{ n_runs_multiplier }}; x0 += 0.1) {
      for (Scalar x1 = 0.1; x1 < {{ n_runs_multiplier }}; x1 += 0.1) {
        auto mat = ComputeSparse{{ matrix_name_camel }}(
          {% for i in range(n_symbols) %}x{{ i }}{% if not loop.last %},{% endif %}{% endfor %}
        );
        // Fold one entry of the result into the accumulator so each iteration's output is used.
        sum += mat.valuePtr()[0];
      }
    }
  }

  // Consume the accumulator outside the timed scope: a volatile write is an observable side
  // effect, so the values feeding `sum` must actually be computed.
  volatile Scalar sink = sum;
  static_cast<void>(sink);
}

// Benchmark: times the Eigen dynamic-size dense evaluation of A^T * B
// (ComputeDenseDynamic{{ matrix_name_camel }}), once for each scalar type (double, float).
//
// x0 and x1 are swept from 0.1 in steps of 0.1 while below n_runs_multiplier; every other symbol
// is held at a fixed constant.  Entry (0, 0) of each result is accumulated into `sum`, and `sum`
// is published through a volatile after the timed scope so the optimizer cannot treat the timed
// computation as dead code.
TEMPLATE_TEST_CASE("dense_dynamic_{{ matrix_name }}", "", double, float) {
  using Scalar = TestType;

  fmt::print("n_runs_multiplier: {};\n", {{ n_runs_multiplier }});

  // Fixed values for the symbols that are not swept by the loops below.
  {% for i in range(2, n_symbols) %}
  const Scalar x{{ i }} = {{ i - 1 }}.0;
  {% endfor %}

  // Pause briefly before the timed section starts.
  std::chrono::milliseconds timespan(100);
  std::this_thread::sleep_for(timespan);

  Scalar sum = 0.0;
  {
    SYM_TIME_SCOPE("dense_dynamic_{{ matrix_name }}_{}", typeid(Scalar).name());
    for (Scalar x0 = 0.1; x0 < {{ n_runs_multiplier }}; x0 += 0.1) {
      for (Scalar x1 = 0.1; x1 < {{ n_runs_multiplier }}; x1 += 0.1) {
        auto mat = ComputeDenseDynamic{{ matrix_name_camel }}<Scalar>(
          {% for i in range(n_symbols) %}x{{ i }}{% if not loop.last %},{% endif %}{% endfor %}
        );
        // Fold one entry of the result into the accumulator so each iteration's output is used.
        sum += mat(0, 0);
      }
    }
  }

  // Consume the accumulator outside the timed scope: a volatile write is an observable side
  // effect, so the values feeding `sum` must actually be computed.
  volatile Scalar sink = sum;
  static_cast<void>(sink);
}

{% if not cant_allocate_on_stack %}
// Benchmark: times the Eigen fixed-size dense evaluation of A^T * B
// (ComputeDenseFixed{{ matrix_name_camel }}), once for each scalar type (double, float).
//
// x0 and x1 are swept from 0.1 in steps of 0.1 while below n_runs_multiplier; every other symbol
// is held at a fixed constant.  Entry (0, 0) of each result is accumulated into `sum`, and `sum`
// is published through a volatile after the timed scope so the optimizer cannot treat the timed
// computation as dead code.
TEMPLATE_TEST_CASE("dense_fixed_{{ matrix_name }}", "", double, float) {
  using Scalar = TestType;

  fmt::print("n_runs_multiplier: {};\n", {{ n_runs_multiplier }});

  // Fixed values for the symbols that are not swept by the loops below.
  {% for i in range(2, n_symbols) %}
  const Scalar x{{ i }} = {{ i - 1 }}.0;
  {% endfor %}

  // Pause briefly before the timed section starts.
  std::chrono::milliseconds timespan(100);
  std::this_thread::sleep_for(timespan);

  Scalar sum = 0.0;
  {
    SYM_TIME_SCOPE("dense_fixed_{{ matrix_name }}_{}", typeid(Scalar).name());
    for (Scalar x0 = 0.1; x0 < {{ n_runs_multiplier }}; x0 += 0.1) {
      for (Scalar x1 = 0.1; x1 < {{ n_runs_multiplier }}; x1 += 0.1) {
        auto mat = ComputeDenseFixed{{ matrix_name_camel }}<Scalar>(
          {% for i in range(n_symbols) %}x{{ i }}{% if not loop.last %},{% endif %}{% endfor %}
        );
        // Fold one entry of the result into the accumulator so each iteration's output is used.
        sum += mat(0, 0);
      }
    }
  }

  // Consume the accumulator outside the timed scope: a volatile write is an observable side
  // effect, so the values feeding `sum` must actually be computed.
  volatile Scalar sink = sum;
  static_cast<void>(sink);
}
{% endif %}

// Benchmark: times the generated flattened A^T * B function (ComputeAtB{{ matrix_name_camel }}),
// once for each scalar type (double, float).
//
// x0 and x1 are swept from 0.1 in steps of 0.1 while below n_runs_multiplier; every other symbol
// is held at a fixed constant.  One entry of each result is accumulated into `sum` (read through
// valuePtr() when the generated result is sparse, otherwise through operator()), and `sum` is
// published through a volatile after the timed scope so the optimizer cannot treat the timed
// computation as dead code.
TEMPLATE_TEST_CASE("flattened_{{ matrix_name }}", "", double, float) {
  using Scalar = TestType;

  fmt::print("n_runs_multiplier: {};\n", {{ n_runs_multiplier }});

  // Fixed values for the symbols that are not swept by the loops below.
  {% for i in range(2, n_symbols) %}
  const Scalar x{{ i }} = {{ i - 1 }}.0;
  {% endfor %}

  // Pause briefly before the timed section starts.
  std::chrono::milliseconds timespan(100);
  std::this_thread::sleep_for(timespan);

  Scalar sum = 0.0;
  {
    SYM_TIME_SCOPE("flattened_{{ matrix_name }}_{}", typeid(Scalar).name());
    for (Scalar x0 = 0.1; x0 < {{ n_runs_multiplier }}; x0 += 0.1) {
      for (Scalar x1 = 0.1; x1 < {{ n_runs_multiplier }}; x1 += 0.1) {
        auto mat = ComputeAtB{{ matrix_name_camel }}(
          {% for i in range(n_symbols) %}x{{ i }}{% if not loop.last %},{% endif %}{% endfor %}
        );
        // Fold one entry of the result into the accumulator so each iteration's output is used.
        {% if symforce_result_is_sparse %}
        sum += mat.valuePtr()[0];
        {% else %}
        sum += mat(0, 0);
        {% endif %}
      }
    }
  }

  // Consume the accumulator outside the timed scope: a volatile write is an observable side
  // effect, so the values feeding `sum` must actually be computed.
  volatile Scalar sink = sum;
  static_cast<void>(sink);
}
